## Get stationary distribution

"""
    get_stationary_distrib(; initial_distrib, entry_distrib, exit_proba, crit = 1e-12,
                           max_iter = 1000, verbose = false, transition = transition_z)

Iterate the law of motion of the firm distribution until it reaches a fixed point.

Each iteration:
1. firms exit with probability `exit_proba` (state by state);
2. survivors move across productivity states via `transition' * survivors`, i.e.
   `transition[i, j]` acts as the probability of moving from state `i` to state `j`;
3. the exiting mass is replaced by entrants drawn from `entry_distrib` (normalized to sum
   to one here), so entry and exit always balance and total mass is preserved.

# Keywords
- `initial_distrib`: starting distribution (not mutated).
- `entry_distrib`: entrants' distribution over states; any positive scaling is accepted.
- `exit_proba`: exit probability per state.
- `crit`: tolerance on the max-abs change between iterations.
- `max_iter`: iteration cap; a warning is emitted if the tolerance is not reached.
- `verbose`: print the distance at each iteration.
- `transition`: productivity transition matrix; defaults to the module-level `transition_z`.

# Returns
Named tuple `(SS_distrib, mass_entry)`. `mass_entry` is the exiting mass of the last
iteration, which equals the entering mass.
"""
function get_stationary_distrib(; initial_distrib, entry_distrib, exit_proba, crit = 1e-12,
                                max_iter = 1000, verbose = false, transition = transition_z)

    # normalize the entry distribution so entrants exactly replace exiters
    entry_shares = entry_distrib ./ sum(entry_distrib)
    survival_proba = 1.0 .- exit_proba

    prev_distrib = copy(initial_distrib)
    new_distrib = copy(initial_distrib)
    dist = Inf
    iter = 1
    mass_exit = 0.0

    while dist > crit && iter < max_iter
        if verbose
            println("Iteration ", iter, " with difference: ", dist)
        end

        # exit, then productivity transition of survivors, then re-entry of the exiting mass
        survivor_distrib = prev_distrib .* survival_proba
        mass_exit = sum(prev_distrib .* exit_proba)
        new_distrib .= transition' * survivor_distrib .+ entry_shares .* mass_exit

        dist = maximum(abs.(new_distrib .- prev_distrib))
        iter += 1
        copyto!(prev_distrib, new_distrib)
    end

    # `!(dist <= crit)` also flags a NaN distance, which silently ends the loop above
    if !(dist <= crit)
        @warn "get_stationary_distrib did not converge" iter dist crit
    end

    return (SS_distrib = new_distrib, mass_entry = mass_exit)
end

# this function is still incomplete!! 
"""
    find_equilibrium_baseline(; param, guess_w = nothing, guess_Y = nothing,
                              update_param_w = 0.5, update_param_Y = 0.8, crit = 1e-6,
                              max_iter = 500, max_iter_w = 500, verbose = false)

Solve for the baseline stationary equilibrium `(w, Y)` by nested damped iteration.

Still incomplete: this mirrors `find_equilibrium_cf`, but calls `get_VF_cf` without any
policy keywords and uses the module-level `vat_tax` in labor demand.

The inner loop adjusts the wage until the free-entry condition
`sum(z_star_uncond_distrib .* VF) == w * entry_cost` holds. The outer loop then:
- computes the stationary firm distribution;
- computes the firm mass that clears the labor market (incumbents' labor plus `entry_cost`
  units of labor per entrant);
- moves `Y` towards the output implied by that allocation.

# Keywords
- `param`: model parameters, forwarded to `get_VF_cf`.
- `guess_w`, `guess_Y`: initial guesses for the wage and aggregate output. When omitted, the
  module-level globals `guess_w` / `guess_Y` are used (the previous implicit behaviour).
- `update_param_w`, `update_param_Y`: damping weights of the wage / output updates.
- `crit`: tolerance on the relative free-entry and output gaps.
- `max_iter`, `max_iter_w`: iteration caps of the outer and inner loops.
- `verbose`: print progress information.

# Returns
Named tuple `(w, Y, VF_results, distrib_results, mass, output_VA)`.
"""
function find_equilibrium_baseline(; param, guess_w = nothing, guess_Y = nothing,
                                   update_param_w = 0.5, update_param_Y = 0.8, crit = 1e-6,
                                   max_iter = 500, max_iter_w = 500, verbose = false)

    # Backward compatibility: older callers supplied the starting guesses as globals.
    w_start = isnothing(guess_w) ? getfield(@__MODULE__, :guess_w) : guess_w
    Y_start = isnothing(guess_Y) ? getfield(@__MODULE__, :guess_Y) : guess_Y

    # initialize objects
    diff = Inf
    iter = 1
    guess_w_new = w_start
    guess_Y_new = Y_start
    initial_distrib = z_star_uncond_distrib
    entrant_distrib = z_star_uncond_distrib
    guess_exp_VF = (1/(1-beta))*profits_NC_baseline # warm start, refined inside the loops

    # outer loop on aggregate output Y (the wage is re-solved inside)
    while diff > crit && iter < max_iter

        if verbose
            println("Iteration ",iter," with difference: ",diff, " and Y: ", guess_Y_new)
        end

        ### Step 1: find wage that is in line with free entry condition
        diff_w = Inf
        iter_w = 1
        w_new = guess_w_new
        while abs(diff_w) > crit && iter_w < max_iter_w
            VF_results_w = get_VF_cf(w = w_new, Y = guess_Y_new, param = param,
                                     guess_exp_VF = guess_exp_VF)

            # expected value of entering, relative to the entry cost in wage units
            entry_value = sum(z_star_uncond_distrib .* VF_results_w.VF)
            diff_w = (entry_value - w_new*entry_cost)/(w_new*entry_cost)

            # raise the wage when entry is too profitable, lower it otherwise
            w_new = w_new * (1.0 + diff_w*update_param_w)

            iter_w += 1
            guess_exp_VF = VF_results_w.exp_VF
        end

        eqlbm_w = w_new
        if verbose
            println("Found wage in line with free entry. w = ",eqlbm_w)
        end

        # VF & exit probabilities at the free-entry wage
        VF_results = get_VF_cf(w = eqlbm_w, Y = guess_Y_new, param = param,
                               guess_exp_VF = guess_exp_VF)

        ### Step 2: stationary distribution & firm mass clearing the labor market
        distrib_results = get_stationary_distrib(
            initial_distrib = initial_distrib,
            entry_distrib = entrant_distrib,
            exit_proba = VF_results.exit_proba)

        # labor per unit mass of firms: incumbents' labor demand + labor spent on entry
        labor_per_firm = sum( (1-vat_tax) .* beta_gross .* (VF_results.revenue ./ eqlbm_w) .* distrib_results.SS_distrib) + distrib_results.mass_entry*entry_cost
        new_mass = aggregate_labor_supply / labor_per_firm

        ### Step 3: implied output & damped update of the guess
        total_Y = sum(VF_results.output_aggr .* distrib_results.SS_distrib .* new_mass)
        # gap measured so that the update below moves Y towards total_Y
        # (the previous (guess_Y - total_Y)/total_Y pushed Y away from it)
        diff_Y = (total_Y - guess_Y_new)/guess_Y_new

        if verbose
            println("Difference Y is: ", diff_Y, " with new mass: ", new_mass)
        end

        guess_Y_new = guess_Y_new * (1.0 + diff_Y*update_param_Y)

        diff = abs(diff_Y)
        iter += 1
        guess_exp_VF = VF_results.exp_VF
        guess_w_new = eqlbm_w
    end

    if !(diff <= crit)
        @warn "find_equilibrium_baseline: output did not converge" iter diff crit
    end

    ### compute final objects at the last (w, Y)
    VF_results = get_VF_cf(w = guess_w_new, Y = guess_Y_new, param = param,
                           guess_exp_VF = guess_exp_VF)

    distrib_results = get_stationary_distrib(
        initial_distrib = initial_distrib,
        entry_distrib = entrant_distrib,
        exit_proba = VF_results.exit_proba)

    labor_per_firm = sum( (1-vat_tax) .* beta_gross .* (VF_results.revenue ./ guess_w_new) .* distrib_results.SS_distrib) + distrib_results.mass_entry*entry_cost
    final_mass = aggregate_labor_supply / labor_per_firm

    # value added output = gross output minus intermediates (gamma_gross share of revenue)
    total_intermediates = sum((gamma_gross .* VF_results.revenue) .* distrib_results.SS_distrib) * final_mass
    output_VA = guess_Y_new - total_intermediates

    return (w = guess_w_new, Y = guess_Y_new, VF_results = VF_results, distrib_results = distrib_results, mass = final_mass, output_VA = output_VA)
end

"""
    find_equilibrium_cf(; guess_w, guess_Y, type = "NC", tax = nothing, subsidy = nothing,
                        param, update_param_w = 0.5, update_param_Y = 0.8, crit = 1e-6,
                        max_iter = 500, max_iter_w = 500, verbose = false)

Solve for the stationary equilibrium `(w, Y)` of a counterfactual policy regime.

The inner loop adjusts the wage until the free-entry condition
`sum(z_star_uncond_distrib .* VF) == w * entry_cost` holds. The outer loop then:
- computes the stationary firm distribution;
- computes the firm mass that clears the labor market (incumbents' labor plus `entry_cost`
  units of labor per entrant);
- moves `Y` towards the output implied by that allocation.

# Keywords
- `guess_w`, `guess_Y`: initial guesses for the wage and aggregate output.
- `type`, `subsidy`: policy regime, forwarded to `get_VF_cf`.
- `tax`: tax rate; `nothing` means the module-level `vat_tax`.
- `param`: model parameters, forwarded to `get_VF_cf`.
- `update_param_w`, `update_param_Y`: damping weights of the wage / output updates.
- `crit`: tolerance on the relative free-entry and output gaps.
- `max_iter`, `max_iter_w`: iteration caps of the outer and inner loops.
- `verbose`: print progress information (all progress output is controlled by this flag).

# Returns
Named tuple `(w, Y, VF_results, distrib_results, mass, output_VA)`; `output_VA` may still
include rent-seeking.
"""
function find_equilibrium_cf(; guess_w, guess_Y, type = "NC", tax = nothing, subsidy = nothing, param, update_param_w = 0.5, update_param_Y = 0.8, crit = 1e-6, max_iter = 500, max_iter_w = 500, verbose = false)

    # default to the baseline VAT; a separate name keeps the rate's type stable
    tax_rate = isnothing(tax) ? copy(vat_tax) : tax

    # initialize objects
    diff = Inf
    iter = 1
    guess_w_new = guess_w
    guess_Y_new = guess_Y
    initial_distrib = z_star_uncond_distrib
    entrant_distrib = z_star_uncond_distrib
    guess_exp_VF = (1/(1-beta))*profits_NC_baseline # warm start, refined inside the loops

    # outer loop on aggregate output Y (the wage is re-solved inside)
    while diff > crit && iter < max_iter

        if verbose
            println("Iteration ",iter," with difference: ",diff, " and Y: ", guess_Y_new)
        end

        ### Step 1: find wage that is in line with free entry condition
        diff_w = Inf
        iter_w = 1
        w_new = guess_w_new
        while abs(diff_w) > crit && iter_w < max_iter_w
            VF_results_w = get_VF_cf(w = w_new, Y = guess_Y_new, type = type, subsidy = subsidy,
                                     tax = tax_rate, param = param, guess_exp_VF = guess_exp_VF)

            # expected value of entering, relative to the entry cost in wage units
            entry_value = sum(z_star_uncond_distrib .* VF_results_w.VF)
            diff_w = (entry_value - w_new*entry_cost)/(w_new*entry_cost)

            # raise the wage when entry is too profitable, lower it otherwise
            w_new = w_new * (1.0 + diff_w*update_param_w)

            iter_w += 1
            guess_exp_VF = VF_results_w.exp_VF
        end

        eqlbm_w = w_new
        if verbose
            println("Found wage in line with free entry. w = ",eqlbm_w)
        end

        # VF & exit probabilities at the free-entry wage
        VF_results = get_VF_cf(w = eqlbm_w, Y = guess_Y_new, type = type, subsidy = subsidy,
                               tax = tax_rate, param = param, guess_exp_VF = guess_exp_VF)

        ### Step 2: stationary distribution & firm mass clearing the labor market
        distrib_results = get_stationary_distrib(
            initial_distrib = initial_distrib,
            entry_distrib = entrant_distrib,
            exit_proba = VF_results.exit_proba)

        # labor per unit mass of firms: incumbents' labor demand + labor spent on entry
        labor_per_firm = sum( (1-tax_rate) .* beta_gross .* (VF_results.revenue ./ eqlbm_w) .* distrib_results.SS_distrib) + distrib_results.mass_entry*entry_cost
        new_mass = aggregate_labor_supply / labor_per_firm

        ### Step 3: implied output & damped update of the guess
        total_Y = sum(VF_results.output_aggr .* distrib_results.SS_distrib .* new_mass)
        diff_Y = (total_Y - guess_Y_new)/guess_Y_new

        if verbose
            println("Difference Y is: ", diff_Y, " with new mass: ", new_mass)
        end

        guess_Y_new = guess_Y_new * (1.0 + diff_Y*update_param_Y)

        diff = abs(diff_Y)
        iter += 1
        guess_exp_VF = VF_results.exp_VF
        guess_w_new = eqlbm_w
    end

    if !(diff <= crit)
        @warn "find_equilibrium_cf: output did not converge" iter diff crit
    end

    ### compute final objects at the last (w, Y)
    VF_results = get_VF_cf(w = guess_w_new, Y = guess_Y_new, type = type, subsidy = subsidy,
                           tax = tax_rate, param = param, guess_exp_VF = guess_exp_VF)

    distrib_results = get_stationary_distrib(
        initial_distrib = initial_distrib,
        entry_distrib = entrant_distrib,
        exit_proba = VF_results.exit_proba)

    labor_per_firm = sum( (1-tax_rate) .* beta_gross .* (VF_results.revenue ./ guess_w_new) .* distrib_results.SS_distrib) + distrib_results.mass_entry*entry_cost
    final_mass = aggregate_labor_supply / labor_per_firm

    # value added output (may still include rent-seeking!!)
    total_intermediates = sum((gamma_gross .* VF_results.revenue) .* distrib_results.SS_distrib) * final_mass
    output_VA = guess_Y_new - total_intermediates

    return (w = guess_w_new, Y = guess_Y_new, VF_results = VF_results, distrib_results = distrib_results, mass = final_mass, output_VA = output_VA)
end

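# Variant of find_equilibrium_cf that assumes the entry cost is paid in final goods:
# entry does not add to labor demand, and labor supply is given by the argument aggr_L.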
"""
    find_equilibrium_cf_goods(; guess_w, guess_Y, aggr_L, type = "NC", tax = nothing,
                              subsidy = nothing, param, update_param_w = 0.5,
                              update_param_Y = 0.8, crit = 1e-6, max_iter = 500,
                              max_iter_w = 500, verbose = false)

Variant of `find_equilibrium_cf` in which entry uses no labor. The firm mass clears the
labor supply `aggr_L` using incumbents' labor demand only.

NOTE(review): the free-entry condition still compares the entry value with
`w * entry_cost`. The original code flagged keeping or dropping `w` there as an open choice.

# Keywords
- `guess_w`, `guess_Y`: initial guesses for the wage and aggregate output.
- `aggr_L`: aggregate labor supply to be cleared.
- `type`, `subsidy`: policy regime, forwarded to `get_VF_cf`.
- `tax`: tax rate; `nothing` means the module-level `vat_tax`.
- `param`: model parameters, forwarded to `get_VF_cf`.
- `update_param_w`, `update_param_Y`: damping weights of the wage / output updates.
- `crit`: tolerance on the relative free-entry and output gaps.
- `max_iter`, `max_iter_w`: iteration caps of the outer and inner loops.
- `verbose`: print progress information (all progress output is controlled by this flag).

# Returns
Named tuple `(w, Y, VF_results, distrib_results, mass, output_VA)`; `output_VA` may still
include rent-seeking.
"""
function find_equilibrium_cf_goods(; guess_w, guess_Y, aggr_L, type = "NC", tax = nothing, subsidy = nothing, param, update_param_w = 0.5, update_param_Y = 0.8, crit = 1e-6, max_iter = 500, max_iter_w = 500, verbose = false)

    # default to the baseline VAT; a separate name keeps the rate's type stable
    tax_rate = isnothing(tax) ? copy(vat_tax) : tax

    # initialize objects
    diff = Inf
    iter = 1
    guess_w_new = guess_w
    guess_Y_new = guess_Y
    initial_distrib = z_star_uncond_distrib
    entrant_distrib = z_star_uncond_distrib
    guess_exp_VF = (1/(1-beta))*profits_NC_baseline # warm start, refined inside the loops

    # outer loop on aggregate output Y (the wage is re-solved inside)
    while diff > crit && iter < max_iter

        if verbose
            println("Iteration ",iter," with difference: ",diff, " and Y: ", guess_Y_new)
        end

        ### Step 1: find wage that is in line with free entry condition
        diff_w = Inf
        iter_w = 1
        w_new = guess_w_new
        while abs(diff_w) > crit && iter_w < max_iter_w
            VF_results_w = get_VF_cf(w = w_new, Y = guess_Y_new, type = type, subsidy = subsidy,
                                     tax = tax_rate, param = param, guess_exp_VF = guess_exp_VF)

            # expected value of entering vs. entry cost (could leave w_new in or take out)
            entry_value = sum(z_star_uncond_distrib .* VF_results_w.VF)
            diff_w = (entry_value - w_new*entry_cost)/(w_new*entry_cost)

            # raise the wage when entry is too profitable, lower it otherwise
            w_new = w_new * (1.0 + diff_w*update_param_w)

            iter_w += 1
            guess_exp_VF = VF_results_w.exp_VF
        end

        eqlbm_w = w_new
        if verbose
            println("Found wage in line with free entry. w = ",eqlbm_w)
        end

        # VF & exit probabilities at the free-entry wage
        VF_results = get_VF_cf(w = eqlbm_w, Y = guess_Y_new, type = type, subsidy = subsidy,
                               tax = tax_rate, param = param, guess_exp_VF = guess_exp_VF)

        ### Step 2: stationary distribution & firm mass clearing the labor market
        distrib_results = get_stationary_distrib(
            initial_distrib = initial_distrib,
            entry_distrib = entrant_distrib,
            exit_proba = VF_results.exit_proba)

        # labor per unit mass of firms (incumbents only: entry is paid in goods)
        labor_per_firm = sum( (1-tax_rate) .* beta_gross .* (VF_results.revenue ./ eqlbm_w) .* distrib_results.SS_distrib)
        new_mass = aggr_L / labor_per_firm

        ### Step 3: implied output & damped update of the guess
        total_Y = sum(VF_results.output_aggr .* distrib_results.SS_distrib .* new_mass)
        diff_Y = (total_Y - guess_Y_new)/guess_Y_new

        if verbose
            println("Difference Y is: ", diff_Y, " with new mass: ", new_mass)
        end

        guess_Y_new = guess_Y_new * (1.0 + diff_Y*update_param_Y)

        diff = abs(diff_Y)
        iter += 1
        guess_exp_VF = VF_results.exp_VF
        guess_w_new = eqlbm_w
    end

    if !(diff <= crit)
        @warn "find_equilibrium_cf_goods: output did not converge" iter diff crit
    end

    ### compute final objects at the last (w, Y)
    VF_results = get_VF_cf(w = guess_w_new, Y = guess_Y_new, type = type, subsidy = subsidy,
                           tax = tax_rate, param = param, guess_exp_VF = guess_exp_VF)

    distrib_results = get_stationary_distrib(
        initial_distrib = initial_distrib,
        entry_distrib = entrant_distrib,
        exit_proba = VF_results.exit_proba)

    labor_per_firm = sum( (1-tax_rate) .* beta_gross .* (VF_results.revenue ./ guess_w_new) .* distrib_results.SS_distrib)
    final_mass = aggr_L / labor_per_firm

    # value added output (may still include rent-seeking!!)
    total_intermediates = sum((gamma_gross .* VF_results.revenue) .* distrib_results.SS_distrib) * final_mass
    output_VA = guess_Y_new - total_intermediates

    return (w = guess_w_new, Y = guess_Y_new, VF_results = VF_results, distrib_results = distrib_results, mass = final_mass, output_VA = output_VA)
end

# find equilibrium with an endogenous VAT rate that balances tax revenue against baseline_govt_spending
"""
    find_equilibrium_cf_tax(; guess_w, guess_Y, guess_tax, param, update_param_w = 0.5,
                            update_param_Y = 0.8, update_param_tax = 0.5, crit = 1e-6,
                            max_iter = 500, max_iter_w = 500, verbose = false)

Solve for the stationary equilibrium `(w, Y, tax)`. The tax rate is set so that total tax
revenue matches the module-level `baseline_govt_spending`.

Each outer iteration:
- solves the wage from the free-entry condition (inner loop);
- computes the stationary distribution and the firm mass clearing the labor market
  (incumbents' labor plus `entry_cost` units of labor per entrant);
- applies damped multiplicative updates to `Y` (towards implied output) and to `tax`
  (downwards when revenue overshoots the target).

Tax revenue is computed as
`sum(exp_profits .* distrib) * profit_tax/(1-profit_tax) * mass` plus
`sum(tax .* (1-gamma_gross) .* revenue .* distrib) * mass`.

# Returns
Named tuple `(w, Y, tax, VF_results, distrib_results, mass, output_VA, total_tax_revenue)`.
"""
function find_equilibrium_cf_tax(; guess_w, guess_Y, guess_tax, param, update_param_w = 0.5, update_param_Y = 0.8, update_param_tax = 0.5, crit = 1e-6, max_iter = 500, max_iter_w = 500, verbose = false)

    # initialize objects
    diff = Inf
    iter = 1
    guess_w_new = guess_w
    guess_Y_new = guess_Y
    guess_tax_new = guess_tax
    initial_distrib = z_star_uncond_distrib
    entrant_distrib = z_star_uncond_distrib
    guess_exp_VF = (1/(1-beta))*profits_NC_baseline # warm start, refined inside the loops

    # outer loop on (Y, tax); the wage is re-solved inside
    while diff > crit && iter < max_iter

        if verbose
            println("Iteration ",iter," with difference: ",diff, ", Y: ", guess_Y_new, " and tax: ", guess_tax_new)
        end

        ### Step 1: find wage that is in line with free entry condition
        diff_w = Inf
        iter_w = 1
        w_new = guess_w_new
        while abs(diff_w) > crit && iter_w < max_iter_w
            VF_results_w = get_VF_cf(w = w_new, Y = guess_Y_new, tax = guess_tax_new,
                                     param = param, guess_exp_VF = guess_exp_VF)

            # expected value of entering, relative to the entry cost in wage units
            entry_value = sum(z_star_uncond_distrib .* VF_results_w.VF)
            diff_w = (entry_value - w_new*entry_cost)/(w_new*entry_cost)

            # raise the wage when entry is too profitable, lower it otherwise
            w_new = w_new * (1.0 + diff_w*update_param_w)

            iter_w += 1
            guess_exp_VF = VF_results_w.exp_VF
        end

        eqlbm_w = w_new
        if verbose
            println("Found wage in line with free entry. w = ",eqlbm_w)
        end

        # VF & exit probabilities at the free-entry wage
        VF_results = get_VF_cf(w = eqlbm_w, Y = guess_Y_new, tax = guess_tax_new,
                               param = param, guess_exp_VF = guess_exp_VF)

        ### Step 2: stationary distribution & firm mass clearing the labor market
        distrib_results = get_stationary_distrib(
            initial_distrib = initial_distrib,
            entry_distrib = entrant_distrib,
            exit_proba = VF_results.exit_proba)

        # labor per unit mass of firms: incumbents' labor demand + labor spent on entry
        labor_per_firm = sum( (1-guess_tax_new) .* beta_gross .* (VF_results.revenue ./ eqlbm_w) .* distrib_results.SS_distrib) + distrib_results.mass_entry*entry_cost
        new_mass = aggregate_labor_supply / labor_per_firm

        ### Step 3: implied output & damped update of the Y guess
        total_Y = sum(VF_results.output_aggr .* distrib_results.SS_distrib .* new_mass)
        diff_Y = (total_Y - guess_Y_new)/guess_Y_new

        if verbose
            println("Difference Y is: ", diff_Y, " with new mass: ", new_mass)
        end

        guess_Y_new = guess_Y_new * (1.0 + diff_Y*update_param_Y)

        ### Step 4: tax revenue vs. spending target & damped update of the tax guess
        cit_revenue = sum( VF_results.exp_profits .* distrib_results.SS_distrib) * (profit_tax/(1-profit_tax)) * new_mass
        vat_revenue = sum( guess_tax_new .* (1-gamma_gross) .* VF_results.revenue .* distrib_results.SS_distrib) * new_mass
        total_tax_revenue = cit_revenue + vat_revenue

        diff_tax = (total_tax_revenue - baseline_govt_spending)/baseline_govt_spending

        if verbose
            println("Difference tax revenue is: ", diff_tax)
        end

        # lower the rate when revenue exceeds the target (an alternative, not used here, is to
        # solve for the implied rate holding everything else fixed and damp towards it)
        guess_tax_new = guess_tax_new * (1.0 - diff_tax*update_param_tax)

        diff = max(abs(diff_Y), abs(diff_tax))
        iter += 1
        guess_exp_VF = VF_results.exp_VF
        guess_w_new = eqlbm_w
    end

    if !(diff <= crit)
        @warn "find_equilibrium_cf_tax: (Y, tax) did not converge" iter diff crit
    end

    ### compute final objects at the last (w, Y, tax)
    VF_results = get_VF_cf(w = guess_w_new, Y = guess_Y_new, tax = guess_tax_new,
                           param = param, guess_exp_VF = guess_exp_VF)

    distrib_results = get_stationary_distrib(
        initial_distrib = initial_distrib,
        entry_distrib = entrant_distrib,
        exit_proba = VF_results.exit_proba)

    labor_per_firm = sum( (1-guess_tax_new) .* beta_gross .* (VF_results.revenue ./ guess_w_new) .* distrib_results.SS_distrib) + distrib_results.mass_entry*entry_cost
    final_mass = aggregate_labor_supply / labor_per_firm

    # value added output = gross output minus intermediates (gamma_gross share of revenue)
    total_intermediates = sum((gamma_gross .* VF_results.revenue) .* distrib_results.SS_distrib) * final_mass
    output_VA = guess_Y_new - total_intermediates

    # total tax revenue at the final allocation
    cit_revenue = sum( VF_results.exp_profits .* distrib_results.SS_distrib) * (profit_tax/(1-profit_tax)) * final_mass
    vat_revenue = sum( guess_tax_new .* (1-gamma_gross) .* VF_results.revenue .* distrib_results.SS_distrib) * final_mass
    total_tax_revenue = cit_revenue + vat_revenue

    return (w = guess_w_new, Y = guess_Y_new, tax = guess_tax_new, VF_results = VF_results, distrib_results = distrib_results, mass = final_mass, output_VA = output_VA, total_tax_revenue = total_tax_revenue)
end

# find equilibrium with an endogenous VAT rate that hits the revenue target aggr_T (entry cost in final goods)
"""
    find_equilibrium_cf_tax_goods(; guess_w, guess_Y, guess_tax, aggr_L, aggr_T, param,
                                  update_param_w = 0.5, update_param_Y = 0.8,
                                  update_param_tax = 0.5, crit = 1e-6, max_iter = 500,
                                  max_iter_w = 500, verbose = false)

Solve for the stationary equilibrium `(w, Y, tax)` when entry uses no labor. The firm mass
clears the labor supply `aggr_L` using incumbents only, and the tax rate is set so that tax
revenue equals `aggr_T`.

Each outer iteration:
- solves the wage from the free-entry condition (inner loop);
- computes the stationary distribution and the labor-clearing firm mass;
- applies damped multiplicative updates to `Y` (towards implied output) and to `tax`
  (downwards when revenue overshoots `aggr_T`).

# Returns
Named tuple `(w, Y, tax, VF_results, distrib_results, mass, output_VA, total_tax_revenue)`.
"""
function find_equilibrium_cf_tax_goods(; guess_w, guess_Y, guess_tax, aggr_L, aggr_T, param, update_param_w = 0.5, update_param_Y = 0.8, update_param_tax = 0.5, crit = 1e-6, max_iter = 500, max_iter_w = 500, verbose = false)

    # initialize objects
    diff = Inf
    iter = 1
    guess_w_new = guess_w
    guess_Y_new = guess_Y
    guess_tax_new = guess_tax
    initial_distrib = z_star_uncond_distrib
    entrant_distrib = z_star_uncond_distrib
    guess_exp_VF = (1/(1-beta))*profits_NC_baseline # warm start, refined inside the loops

    # outer loop on (Y, tax); the wage is re-solved inside
    while diff > crit && iter < max_iter

        if verbose
            println("Iteration ",iter," with difference: ",diff, ", Y: ", guess_Y_new, " and tax: ", guess_tax_new)
        end

        ### Step 1: find wage that is in line with free entry condition
        diff_w = Inf
        iter_w = 1
        w_new = guess_w_new
        while abs(diff_w) > crit && iter_w < max_iter_w
            VF_results_w = get_VF_cf(w = w_new, Y = guess_Y_new, tax = guess_tax_new,
                                     param = param, guess_exp_VF = guess_exp_VF)

            # expected value of entering, relative to w * entry_cost
            entry_value = sum(z_star_uncond_distrib .* VF_results_w.VF)
            diff_w = (entry_value - w_new*entry_cost)/(w_new*entry_cost)

            # raise the wage when entry is too profitable, lower it otherwise
            w_new = w_new * (1.0 + diff_w*update_param_w)

            iter_w += 1
            guess_exp_VF = VF_results_w.exp_VF
        end

        eqlbm_w = w_new
        if verbose
            println("Found wage in line with free entry. w = ",eqlbm_w)
        end

        # VF & exit probabilities at the free-entry wage
        VF_results = get_VF_cf(w = eqlbm_w, Y = guess_Y_new, tax = guess_tax_new,
                               param = param, guess_exp_VF = guess_exp_VF)

        ### Step 2: stationary distribution & firm mass clearing the labor market
        distrib_results = get_stationary_distrib(
            initial_distrib = initial_distrib,
            entry_distrib = entrant_distrib,
            exit_proba = VF_results.exit_proba)

        # labor per unit mass of firms (incumbents only: entry is paid in goods)
        labor_per_firm = sum( (1-guess_tax_new) .* beta_gross .* (VF_results.revenue ./ eqlbm_w) .* distrib_results.SS_distrib)
        new_mass = aggr_L / labor_per_firm

        ### Step 3: implied output & damped update of the Y guess
        total_Y = sum(VF_results.output_aggr .* distrib_results.SS_distrib .* new_mass)
        diff_Y = (total_Y - guess_Y_new)/guess_Y_new

        if verbose
            println("Difference Y is: ", diff_Y, " with new mass: ", new_mass)
        end

        guess_Y_new = guess_Y_new * (1.0 + diff_Y*update_param_Y)

        ### Step 4: tax revenue vs. target & damped update of the tax guess
        cit_revenue = sum( VF_results.exp_profits .* distrib_results.SS_distrib) * (profit_tax/(1-profit_tax)) * new_mass
        vat_revenue = sum( guess_tax_new .* (1-gamma_gross) .* VF_results.revenue .* distrib_results.SS_distrib) * new_mass
        new_T = cit_revenue + vat_revenue

        diff_tax = (new_T - aggr_T)/aggr_T

        if verbose
            println("Difference tax revenue is: ", diff_tax)
        end

        # lower the rate when revenue exceeds the target
        guess_tax_new = guess_tax_new * (1.0 - diff_tax*update_param_tax)

        diff = max(abs(diff_Y), abs(diff_tax))
        iter += 1
        guess_exp_VF = VF_results.exp_VF
        guess_w_new = eqlbm_w
    end

    if !(diff <= crit)
        @warn "find_equilibrium_cf_tax_goods: (Y, tax) did not converge" iter diff crit
    end

    ### compute final objects at the last (w, Y, tax)
    VF_results = get_VF_cf(w = guess_w_new, Y = guess_Y_new, tax = guess_tax_new,
                           param = param, guess_exp_VF = guess_exp_VF)

    distrib_results = get_stationary_distrib(
        initial_distrib = initial_distrib,
        entry_distrib = entrant_distrib,
        exit_proba = VF_results.exit_proba)

    labor_per_firm = sum( (1-guess_tax_new) .* beta_gross .* (VF_results.revenue ./ guess_w_new) .* distrib_results.SS_distrib)
    final_mass = aggr_L / labor_per_firm

    # value added output = gross output minus intermediates (gamma_gross share of revenue)
    total_intermediates = sum((gamma_gross .* VF_results.revenue) .* distrib_results.SS_distrib) * final_mass
    output_VA = guess_Y_new - total_intermediates

    # total tax revenue at the final allocation
    cit_revenue = sum( VF_results.exp_profits .* distrib_results.SS_distrib) * (profit_tax/(1-profit_tax)) * final_mass
    vat_revenue = sum( guess_tax_new .* (1-gamma_gross) .* VF_results.revenue .* distrib_results.SS_distrib) * final_mass
    total_tax_revenue = cit_revenue + vat_revenue

    return (w = guess_w_new, Y = guess_Y_new, tax = guess_tax_new, VF_results = VF_results, distrib_results = distrib_results, mass = final_mass, output_VA = output_VA, total_tax_revenue = total_tax_revenue)
end

# find equilibrium (w, Y, tax) holding the firm distribution and firm mass fixed
# Evaluate firms' static optimal choices on `z_grid_baseline` for the policy regime `type`.
# Returns `(exp_profits, revenue, output_aggr)` with one naming scheme across regimes.
function _fixdistrib_profit_objects(type, w, Y, tax, subsidy, param)
    if type == "NC"
        res = compute_profits_NC(w = w, Y = Y, tax = tax, param = param, z_grid = z_grid_baseline)
        return (exp_profits = res.profits, revenue = res.revenue, output_aggr = res.output_aggr)
    elseif type == "C"
        res = compute_profits_grid_cf(w = w, Y = Y, tax = tax, n_eps = n_eps, param = param, z_grid = z_grid_baseline)
    elseif type == "C-fix"
        isnothing(subsidy) && throw(ArgumentError("Need to specify subsidy rate for type == \"C-fix\""))
        res = compute_profits_grid_cf_fix_subsidy(w = w, Y = Y, subsidy = subsidy, tax = tax, param = param, z_grid = z_grid_baseline)
    else
        throw(ArgumentError("No other options than type == {C,NC,C-fix}, got: $(type)"))
    end
    # the "C" and "C-fix" solvers share field names for expected objects
    return (exp_profits = res.expected_profits, revenue = res.expected_revenue, output_aggr = res.expected_revenue_output)
end

"""
    find_equilibrium_cf_tax_goods_fixdistrib(; guess_w, guess_Y, guess_tax, distrib, mass,
        aggr_L, aggr_T, param, type = "NC", subsidy = nothing, update_param_w = 0.5,
        update_param_Y = 0.8, update_param_tax = 0.5, crit = 1e-6, max_iter = 500,
        verbose = false)

Solve jointly for the wage, aggregate output and tax rate while holding the firm
distribution `distrib` and the firm `mass` fixed. No value function, entry or exit is
involved; only firms' static choices are used.

Each iteration evaluates firms' choices at the current guesses and applies three damped
multiplicative updates, all based on gaps measured at the current guesses:
- wage: rises when labor demand exceeds `aggr_L` (labor-market clearing);
- output: moves towards the output implied by `distrib` and `mass`;
- tax: falls when tax revenue exceeds `aggr_T`.

# Keywords
- `type`: `"NC"`, `"C"` or `"C-fix"`; any other value throws an `ArgumentError`.
- `subsidy`: subsidy rate; required (`ArgumentError` otherwise) when `type == "C-fix"`.
- `crit`, `max_iter`: tolerance on the largest relative gap and the iteration cap.
- `verbose`: print progress information (all progress output is controlled by this flag).

# Returns
Named tuple `(w, Y, tax, output_VA, total_tax_revenue)`.
"""
function find_equilibrium_cf_tax_goods_fixdistrib(; guess_w, guess_Y, guess_tax, distrib, mass, aggr_L, aggr_T, param, type = "NC", subsidy = nothing, update_param_w = 0.5, update_param_Y = 0.8, update_param_tax = 0.5, crit = 1e-6, max_iter = 500, verbose = false)

    # initialize objects
    diff = Inf
    iter = 1
    guess_w_new = guess_w
    guess_Y_new = guess_Y
    guess_tax_new = guess_tax

    # iterate on guesses for (w, Y, tax)
    while diff > crit && iter < max_iter

        if verbose
            println("Iteration ",iter," with difference: ",diff, ", Y: ", guess_Y_new, " and tax: ", guess_tax_new)
        end

        # firms' optimal choices at the current guesses (no VF needed with a fixed distribution)
        choices = _fixdistrib_profit_objects(type, guess_w_new, guess_Y_new, guess_tax_new, subsidy, param)

        ## labor market clearing -> wage gap
        total_labor_demand = sum( (1-guess_tax_new) .* beta_gross .* (choices.revenue ./ guess_w_new) .* distrib) * mass
        diff_L = (total_labor_demand - aggr_L)/aggr_L

        ## implied output -> output gap
        total_Y = sum(choices.output_aggr .* distrib) * mass
        diff_Y = (total_Y - guess_Y_new)/guess_Y_new

        ## tax revenue vs. target -> tax gap
        cit_revenue = sum( choices.exp_profits .* distrib) * (profit_tax/(1-profit_tax)) * mass
        vat_revenue = sum( guess_tax_new .* (1-gamma_gross) .* choices.revenue .* distrib) * mass
        new_T = cit_revenue + vat_revenue
        diff_tax = (new_T - aggr_T)/aggr_T

        if verbose
            println("Difference Labor demand is: ", diff_L)
            println("Difference Y is: ", diff_Y)
            println("Difference tax revenue is: ", diff_tax)
        end

        # damped updates: excess labor demand raises w, revenue overshoot lowers the tax
        guess_w_new = guess_w_new * (1.0 + diff_L*update_param_w)
        guess_Y_new = guess_Y_new * (1.0 + diff_Y*update_param_Y)
        guess_tax_new = guess_tax_new * (1.0 - diff_tax*update_param_tax)

        diff = maximum([abs(diff_L),abs(diff_Y),abs(diff_tax)])
        iter += 1
    end

    if !(diff <= crit)
        @warn "find_equilibrium_cf_tax_goods_fixdistrib: (w, Y, tax) did not converge" iter diff crit
    end

    ### compute final objects at the last guesses
    w_eq = guess_w_new
    Y_eq = guess_Y_new
    tax_eq = guess_tax_new

    choices = _fixdistrib_profit_objects(type, w_eq, Y_eq, tax_eq, subsidy, param)

    total_Y = sum(choices.output_aggr .* distrib) * mass

    cit_revenue = sum( choices.exp_profits .* distrib) * (profit_tax/(1-profit_tax)) * mass
    vat_revenue = sum( tax_eq .* (1-gamma_gross) .* choices.revenue .* distrib) * mass
    final_T = cit_revenue + vat_revenue

    # value added output = gross output minus intermediates (gamma_gross share of revenue)
    total_intermediates = sum((gamma_gross .* choices.revenue) .* distrib) * mass
    output_VA = total_Y - total_intermediates

    return (w = w_eq, Y = Y_eq, tax = tax_eq, output_VA = output_VA, total_tax_revenue = final_T)
end

# Does not yet allow for a time path of vat_tax (the same vat_tax is used in every period).
function get_distrib_path_cf(; starting_distribution, VF_path, entry_distrib, wage_path)

    ## Initialize objects ## 
    length_transition = length(wage_path)
    distribution_path = repeat([starting_distribution], length_transition)
    mass_entrants = zeros(Float64, length_transition)
    prev_distrib = copy(starting_distribution)
    new_distrib = copy(starting_distribution)
    mass_exit_path = zeros(Float64, length_transition)
    mass_entrant_path = zeros(Float64, length_transition)

    ## Then iterate ## 

    for year in 1:1:length_transition

        ### Step 1: Compute survivors from last period ###
        survivor_distrib = prev_distrib .* (1.0 .- VF_path[year].exit_proba)
        mass_exit_path[year] = sum(prev_distrib .* VF_path[year].exit_proba)

        ### Step 2: Update their productivity ###
        survivor_distrib_new = (transition_z' * survivor_distrib)

        ### Step 3: Find mass of new entrants in line with labor market clearing 

        # Substep a) Get labor demand of survivors 
        total_labor_demand_survivors = sum( (1-vat_tax) .* beta_gross .* (VF_path[year].within_period_objects.revenue ./ wage_path[year]) .* survivor_distrib_new)
        missing_labor_demand = aggregate_labor_supply - total_labor_demand_survivors 
        
        # Substep b) Get mass of entrants that is in line with labor market clearing 
        mass_entrant_path[year] = missing_labor_demand/(sum((1-vat_tax) .* beta_gross .* (VF_path[year].within_period_objects.revenue ./ wage_path[year]) .* entry_distrib) + entry_cost)
        entrant_distrib = entry_distrib .* mass_entrant_path[year] 

        ### Step 4: Put all together 
        new_distrib .= survivor_distrib_new .+ entrant_distrib
        distribution_path[year] = copy(new_distrib) # save in distribution path 

        # Update recursion 
        prev_distrib = copy(new_distrib)
    end 

    return (distribution_path = distribution_path, mass_entrant_path = mass_entrant_path, mass_exit_path = mass_exit_path)
end 

"""
    find_transition_path_cf(; guess_w_path, guess_Y_path, param, VF_results_end,
                            profits_object_end, starting_distribution,
                            update_param_w = 0.5, update_param_Y = 0.8, crit = 1e-6,
                            max_iter = 500, max_iter_w = 500, verbose = false)

Solve for a transition path of wages and output, starting from `starting_distribution`.

This is a nested damped fixed-point iteration:
- Inner loop: given the current output path, scale each period's wage by
  `1 + update_param_w * gap`. Here `gap` is the relative free-entry gap
  `(sum(z_star_uncond_distrib .* VF) - w * entry_cost) / (w * entry_cost)`.
  Stop when the largest absolute gap is at most `crit`, or after `max_iter_w - 1` steps.
- Outer loop: with those wages, solve the value-function path with `get_VF_transition`
  and roll the distribution forward with `get_distrib_path_cf`. Then scale the output
  guess by `1 + update_param_Y * (Y_implied - Y_guess) / Y_guess`. Stop when the largest
  absolute relative output gap is at most `crit`, or after `max_iter - 1` steps.

`VF_results_end` and `profits_object_end` only seed the placeholder `VF_results` that is
returned if the outer loop never runs. `get_VF_transition` is called with `nothing` for both.

Returns `(w_path, Y, distribution_path, mass_exit_path, mass_entrant_path, VF_results,
guess_Y_path, guess_w_path)`:
- `Y` is the output guess after the last update.
- `guess_Y_path` / `guess_w_path` hold the guess used at each outer iteration. Slots that
  were never used keep the initial guess.
"""
function find_transition_path_cf(; guess_w_path, guess_Y_path, param, VF_results_end, profits_object_end, starting_distribution, update_param_w = 0.5, update_param_Y = 0.8, crit = 1e-6, max_iter = 500, max_iter_w = 500, verbose = false)

    T = length(guess_w_path)

    # current guesses for the price paths
    w_guess = copy(guess_w_path)
    Y_guess = copy(guess_Y_path)

    # outputs (placeholders until the first outer iteration completes)
    distribution_path = repeat([starting_distribution], T)
    mass_exit_path = zeros(Float64, T)
    mass_entrant_path = zeros(Float64, T)
    terminal_VF = (VF = VF_results_end.VF, exit_proba = VF_results_end.exit_proba, exp_VF = VF_results_end.exp_VF,
                   exp_fcost = VF_results_end.exp_fcost, within_period_objects = profits_object_end)
    VF_results = repeat([terminal_VF], T)

    # history of the guesses used at each outer iteration (for plotting later)
    Y_history = repeat([Y_guess], max_iter)
    w_history = repeat([w_guess], max_iter)

    max_gap = Inf
    outer = 1
    while max_gap > crit && outer < max_iter
        if verbose
            println("Iteration ",outer," with max difference: ",max_gap)
        end
        Y_history[outer] = copy(Y_guess)
        w_history[outer] = copy(w_guess)

        ## (1) wage path consistent with free entry in every period
        w_trial = copy(w_guess)
        gap_w = Inf
        inner = 1
        while abs(gap_w) > crit && inner < max_iter_w
            VF_trial = get_VF_transition(
                w_path = w_trial, Y_path = Y_guess, param = param,
                VF_results_end = nothing, profits_object_end = nothing,
                guess_exp_VF = nothing,
                verbose = false)
            # relative free-entry gap per period
            entry_gap = Float64[(sum(z_star_uncond_distrib .* VF_trial[t].VF) - w_trial[t]*entry_cost)/(w_trial[t]*entry_cost) for t in 1:T]
            gap_w = maximum(abs.(entry_gap))
            w_trial = w_trial .* (1.0 .+ entry_gap .* update_param_w)
            inner += 1
        end
        w_eq_path = copy(w_trial)
        println("Found wage path in line with free entry.")

        VF_results = get_VF_transition(
                w_path = w_eq_path, Y_path = Y_guess, param = param,
                VF_results_end = nothing, profits_object_end = nothing,
                guess_exp_VF = nothing,
                verbose = false)
        println("Solved for VF along transition.")

        ## (2) distribution path and firm-mass paths
        paths = get_distrib_path_cf(
            starting_distribution = starting_distribution,
            VF_path = VF_results,
            entry_distrib = z_star_uncond_distrib,
            wage_path = w_eq_path)
        distribution_path .= paths.distribution_path
        mass_exit_path .= paths.mass_exit_path
        mass_entrant_path .= paths.mass_entrant_path
        println("Solved for distribution path along transition.")

        ## (3) implied output path and damped update of the guess
        Y_implied = Float64[sum(VF_results[t].within_period_objects.output_aggr .* distribution_path[t]) for t in 1:T]
        Y_gap = Float64[(Y_implied[t] - Y_guess[t])/Y_guess[t] for t in 1:T]
        max_gap = maximum(abs.(Y_gap))
        println("Difference Y is: ", max_gap, " at position: ", argmax(abs.(Y_gap)), " with mass at end: ", sum(distribution_path[T]))

        Y_guess = Y_guess .* (1.0 .+ (Y_gap .* update_param_Y))
        outer += 1
        w_guess = copy(w_eq_path)
    end

    return (w_path = w_guess, Y = Y_guess, distribution_path = distribution_path,
            mass_exit_path = mass_exit_path, mass_entrant_path = mass_entrant_path, VF_results = VF_results,
            guess_Y_path = Y_history, guess_w_path = w_history)
end

"""
    get_distrib_path_cf_goods(; starting_distribution, VF_path, entry_distrib, wage_path, aggr_L)

Roll the firm distribution forward along a transition path of length `length(wage_path)`.
Entry costs do not enter labor demand here. Presumably they are paid in goods (see the
`_goods` naming); confirm with the caller.

In every period `t`:
1. firms exit with the per-grid-point probability `VF_path[t].exit_proba`;
2. survivors' productivity is updated with the global transition matrix `transition_z'`;
3. entrants, distributed as `entry_distrib`, arrive with the mass that makes production labor
   of survivors plus entrants equal `aggr_L`.

Production labor per firm is `(1 - vat_tax) * beta_gross * revenue / w`, with `revenue` taken
from `VF_path[t].within_period_objects.revenue` and `w = wage_path[t]`. `vat_tax` is the
global rate and is constant along the path.

Returns `(distribution_path, mass_entrant_path, mass_exit_path)`.

NOTE(review): the entrant mass is not clamped and can be negative if survivors alone demand
more than `aggr_L`.
"""
function get_distrib_path_cf_goods(; starting_distribution, VF_path, entry_distrib, wage_path, aggr_L)
    T = length(wage_path)
    distribution_path = repeat([starting_distribution], T)
    mass_exit_path = zeros(Float64, T)
    mass_entrant_path = zeros(Float64, T)

    current = copy(starting_distribution)       # distribution entering period t
    next_distrib = copy(starting_distribution)  # buffer for the distribution leaving period t

    for t in 1:T
        exit_p = VF_path[t].exit_proba
        # production labor demanded per firm at each grid point
        labor_per_firm = (1-vat_tax) .* beta_gross .* (VF_path[t].within_period_objects.revenue ./ wage_path[t])

        # exit, then productivity transition of the survivors
        mass_exit_path[t] = sum(current .* exit_p)
        survivors = transition_z' * (current .* (1.0 .- exit_p))

        # entrant mass that closes the labor market
        labor_gap = aggr_L - sum(labor_per_firm .* survivors)
        mass_entrant_path[t] = labor_gap / sum(labor_per_firm .* entry_distrib)

        next_distrib .= survivors .+ entry_distrib .* mass_entrant_path[t]
        distribution_path[t] = copy(next_distrib)
        current = copy(next_distrib)
    end

    return (distribution_path = distribution_path, mass_entrant_path = mass_entrant_path, mass_exit_path = mass_exit_path)
end

"""
    find_transition_path_cf_goods(; guess_w_path, guess_Y_path, param, aggr_L, VF_results_end,
                                  profits_object_end, starting_distribution,
                                  update_param_w = 0.5, update_param_Y = 0.8, crit = 1e-6,
                                  max_iter = 500, max_iter_w = 500, verbose = false)

Solve for a transition path of wages and output, starting from `starting_distribution`. The
distribution path is rolled forward with `get_distrib_path_cf_goods`, so labor-market clearing
uses the supplied `aggr_L`.

This is a nested damped fixed-point iteration:
- Inner loop: scale each period's wage by `1 + update_param_w * gap`. Here `gap` is the
  relative free-entry gap
  `(sum(z_star_uncond_distrib .* VF) - w * entry_cost) / (w * entry_cost)`.
  Stop when the largest absolute gap is at most `crit`, or after `max_iter_w - 1` steps.
- Outer loop: scale the output guess by
  `1 + update_param_Y * (Y_implied - Y_guess) / Y_guess`. Stop when the largest absolute
  relative output gap is at most `crit`, or after `max_iter - 1` steps.

`VF_results_end` and `profits_object_end` only seed the placeholder `VF_results` that is
returned if the outer loop never runs. `get_VF_transition` is called with `nothing` for both.

Returns `(w_path, Y, distribution_path, mass_exit_path, mass_entrant_path, VF_results,
guess_Y_path, guess_w_path)`:
- `Y` is the output guess after the last update.
- The `guess_*_path` entries record the guess used at each outer iteration.
"""
function find_transition_path_cf_goods(; guess_w_path, guess_Y_path, param, aggr_L, VF_results_end, profits_object_end, starting_distribution, update_param_w = 0.5, update_param_Y = 0.8, crit = 1e-6, max_iter = 500, max_iter_w = 500, verbose = false)

    T = length(guess_w_path)

    # current guesses for the price paths
    w_guess = copy(guess_w_path)
    Y_guess = copy(guess_Y_path)

    # outputs (placeholders until the first outer iteration completes)
    distribution_path = repeat([starting_distribution], T)
    mass_exit_path = zeros(Float64, T)
    mass_entrant_path = zeros(Float64, T)
    terminal_VF = (VF = VF_results_end.VF, exit_proba = VF_results_end.exit_proba, exp_VF = VF_results_end.exp_VF,
                   exp_fcost = VF_results_end.exp_fcost, within_period_objects = profits_object_end)
    VF_results = repeat([terminal_VF], T)

    # history of the guesses used at each outer iteration (for plotting later)
    Y_history = repeat([Y_guess], max_iter)
    w_history = repeat([w_guess], max_iter)

    max_gap = Inf
    outer = 1
    while max_gap > crit && outer < max_iter
        if verbose
            println("Iteration ",outer," with max difference: ",max_gap)
        end
        Y_history[outer] = copy(Y_guess)
        w_history[outer] = copy(w_guess)

        ## (1) wage path consistent with free entry in every period
        w_trial = copy(w_guess)
        gap_w = Inf
        inner = 1
        while abs(gap_w) > crit && inner < max_iter_w
            VF_trial = get_VF_transition(
                w_path = w_trial, Y_path = Y_guess, param = param,
                VF_results_end = nothing, profits_object_end = nothing,
                guess_exp_VF = nothing,
                verbose = false)
            # relative free-entry gap per period
            entry_gap = Float64[(sum(z_star_uncond_distrib .* VF_trial[t].VF) - w_trial[t]*entry_cost)/(w_trial[t]*entry_cost) for t in 1:T]
            gap_w = maximum(abs.(entry_gap))
            w_trial = w_trial .* (1.0 .+ entry_gap .* update_param_w)
            inner += 1
        end
        w_eq_path = copy(w_trial)
        println("Found wage path in line with free entry.")

        VF_results = get_VF_transition(
                w_path = w_eq_path, Y_path = Y_guess, param = param,
                VF_results_end = nothing, profits_object_end = nothing,
                guess_exp_VF = nothing,
                verbose = false)
        println("Solved for VF along transition.")

        ## (2) distribution path and firm-mass paths
        paths = get_distrib_path_cf_goods(
            starting_distribution = starting_distribution,
            VF_path = VF_results,
            entry_distrib = z_star_uncond_distrib,
            wage_path = w_eq_path,
            aggr_L = aggr_L)
        distribution_path .= paths.distribution_path
        mass_exit_path .= paths.mass_exit_path
        mass_entrant_path .= paths.mass_entrant_path
        println("Solved for distribution path along transition.")

        ## (3) implied output path and damped update of the guess
        Y_implied = Float64[sum(VF_results[t].within_period_objects.output_aggr .* distribution_path[t]) for t in 1:T]
        Y_gap = Float64[(Y_implied[t] - Y_guess[t])/Y_guess[t] for t in 1:T]
        max_gap = maximum(abs.(Y_gap))
        println("Difference Y is: ", max_gap, " at position: ", argmax(abs.(Y_gap)), " with mass at end: ", sum(distribution_path[T]))

        Y_guess = Y_guess .* (1.0 .+ (Y_gap .* update_param_Y))
        outer += 1
        w_guess = copy(w_eq_path)
    end

    return (w_path = w_guess, Y = Y_guess, distribution_path = distribution_path,
            mass_exit_path = mass_exit_path, mass_entrant_path = mass_entrant_path, VF_results = VF_results,
            guess_Y_path = Y_history, guess_w_path = w_history)
end

# function to solve transition path with finding path of taxes (need to update this: no tax path yet)
"""
    find_transition_path_cf_tax(; guess_w_path, guess_Y_path, param, VF_results_end,
                                profits_object_end, starting_distribution,
                                update_param_w = 0.5, update_param_Y = 0.8, crit = 1e-6,
                                max_iter = 500, max_iter_w = 500, verbose = false)

Solve for a transition path of wages and output, starting from `starting_distribution`. This is
the placeholder for the version that will also solve for a tax path; currently it solves only
for `(w, Y)` paths.

This is a nested damped fixed-point iteration:
- Inner loop: scale each period's wage by `1 + update_param_w * gap`. Here `gap` is the
  relative free-entry gap
  `(sum(z_star_uncond_distrib .* VF) - w * entry_cost) / (w * entry_cost)`.
  Stop when the largest absolute gap is at most `crit`, or after `max_iter_w - 1` steps.
- Outer loop: roll the distribution forward with `get_distrib_path_cf`. Then scale the output
  guess by `1 + update_param_Y * (Y_implied - Y_guess) / Y_guess`. Stop when the largest
  absolute relative output gap is at most `crit`, or after `max_iter - 1` steps.

The output gap is measured as `(Y_implied - Y_guess) / Y_guess`, matching
`find_transition_path_cf`. The previous `(Y_guess - Y_implied) / Y_implied` had the opposite
sign, so the damped update moved the guess away from the implied output and the iteration
diverged.

`VF_results_end` and `profits_object_end` only seed the placeholder `VF_results` that is
returned if the outer loop never runs. `get_VF_transition` is called with `nothing` for both.

Returns `(w_path, Y, distribution_path, mass_exit_path, mass_entrant_path, VF_results,
guess_Y_path, guess_w_path)`:
- `Y` is the output guess after the last update.
- The `guess_*_path` entries record the guess used at each outer iteration.
"""
function find_transition_path_cf_tax(; guess_w_path, guess_Y_path, param, VF_results_end, profits_object_end, starting_distribution, update_param_w = 0.5, update_param_Y = 0.8, crit = 1e-6, max_iter = 500, max_iter_w = 500, verbose = false)

    T = length(guess_w_path)

    # current guesses for the price paths
    w_guess = copy(guess_w_path)
    Y_guess = copy(guess_Y_path)

    # outputs (placeholders until the first outer iteration completes)
    distribution_path = repeat([starting_distribution], T)
    mass_exit_path = zeros(Float64, T)
    mass_entrant_path = zeros(Float64, T)
    terminal_VF = (VF = VF_results_end.VF, exit_proba = VF_results_end.exit_proba, exp_VF = VF_results_end.exp_VF,
                   exp_fcost = VF_results_end.exp_fcost, within_period_objects = profits_object_end)
    VF_results = repeat([terminal_VF], T)

    # history of the guesses used at each outer iteration (for plotting later)
    Y_history = repeat([Y_guess], max_iter)
    w_history = repeat([w_guess], max_iter)

    max_gap = Inf
    outer = 1
    while max_gap > crit && outer < max_iter
        if verbose
            println("Iteration ",outer," with max difference: ",max_gap)
        end
        Y_history[outer] = copy(Y_guess)
        w_history[outer] = copy(w_guess)

        ## (1) wage path consistent with free entry in every period
        w_trial = copy(w_guess)
        gap_w = Inf
        inner = 1
        while abs(gap_w) > crit && inner < max_iter_w
            # NOTE(review): an earlier draft passed VF_results_end / profits_object_end here
            # instead of `nothing`; kept as `nothing` to match the sibling solvers.
            VF_trial = get_VF_transition(
                w_path = w_trial, Y_path = Y_guess, param = param,
                VF_results_end = nothing, profits_object_end = nothing,
                guess_exp_VF = nothing,
                verbose = false)
            # relative free-entry gap per period
            entry_gap = Float64[(sum(z_star_uncond_distrib .* VF_trial[t].VF) - w_trial[t]*entry_cost)/(w_trial[t]*entry_cost) for t in 1:T]
            gap_w = maximum(abs.(entry_gap))
            w_trial = w_trial .* (1.0 .+ entry_gap .* update_param_w)
            inner += 1
        end
        w_eq_path = copy(w_trial)
        println("Found wage path in line with free entry.")

        VF_results = get_VF_transition(
                w_path = w_eq_path, Y_path = Y_guess, param = param,
                VF_results_end = nothing, profits_object_end = nothing,
                guess_exp_VF = nothing,
                verbose = false)
        println("Solved for VF along transition.")

        ## (2) distribution path and firm-mass paths
        paths = get_distrib_path_cf(
            starting_distribution = starting_distribution,
            VF_path = VF_results,
            entry_distrib = z_star_uncond_distrib,
            wage_path = w_eq_path)
        distribution_path .= paths.distribution_path
        mass_exit_path .= paths.mass_exit_path
        mass_entrant_path .= paths.mass_entrant_path
        println("Solved for distribution path along transition.")

        ## (3) implied output path and damped update of the guess
        Y_implied = Float64[sum(VF_results[t].within_period_objects.output_aggr .* distribution_path[t]) for t in 1:T]
        # gap must be (implied - guess)/guess so that Y_guess *= 1 + λ*gap moves toward Y_implied
        Y_gap = Float64[(Y_implied[t] - Y_guess[t])/Y_guess[t] for t in 1:T]
        max_gap = maximum(abs.(Y_gap))
        println("Difference Y is: ", max_gap, " at position: ", argmax(abs.(Y_gap)), " with mass at end: ", sum(distribution_path[T]))

        Y_guess = Y_guess .* (1.0 .+ (Y_gap .* update_param_Y))
        outer += 1
        w_guess = copy(w_eq_path)
    end

    return (w_path = w_guess, Y = Y_guess, distribution_path = distribution_path,
            mass_exit_path = mass_exit_path, mass_entrant_path = mass_entrant_path, VF_results = VF_results,
            guess_Y_path = Y_history, guess_w_path = w_history)
end

# for finding equilibrium with fixed distribution (no entry/exit)
"""
    find_equilibrium_cf_fixdistrib_type(; guess_w, guess_Y, distrib, mass, aggr_L, type = "NC",
                                        tax = nothing, subsidy = nothing, param,
                                        update_param_w = 0.5, update_param_Y = 0.5,
                                        crit = 1e-6, max_iter = 500, verbose = false,
                                        n_eps = 500)

Solve for the wage `w` and output `Y`, holding the firm distribution `distrib` and the total firm
mass `mass` fixed (no entry or exit).

Each iteration evaluates the per-firm objects for firm `type` at the current `(w, Y)`. It then
- scales `w` by `1 + update_param_w * (L_demand - aggr_L) / aggr_L`, and
- scales `Y` by `1 + update_param_Y * (Y_implied - Y) / Y`.

It stops when the larger of the two absolute relative gaps is at most `crit`, or after
`max_iter - 1` iterations.

# Keywords
- `type`: one of
  - `"NC"` (`compute_profits_NC`),
  - `"C"` (`compute_profits_grid_cf`, uses `n_eps`),
  - `"C-fix"` (`compute_profits_grid_cf_fix_subsidy`, requires `subsidy`).
- `tax`: VAT rate; defaults to the global `vat_tax`. For `"C-fix"` the profit routine is called
  with `tax = nothing`, but `tax` still enters the labor-demand formula.
- `subsidy`: subsidy rate, required when `type == "C-fix"`.

# Returns
`(w, Y, output_VA)`, where `output_VA = Y - intermediates - rent_seeking`. Rent seeking is
counted only for `type == "C"`.

# Throws
`ArgumentError` for an unknown `type`, or for `type == "C-fix"` without `subsidy`.
"""
function find_equilibrium_cf_fixdistrib_type(; guess_w, guess_Y, distrib, mass, aggr_L, type = "NC", tax = nothing, subsidy = nothing, param, update_param_w = 0.5, update_param_Y = 0.5, crit = 1e-6, max_iter = 500, verbose = false, n_eps = 500)

    # Validate options up front. Previously a bad option only printed a message and then
    # crashed with an UndefVarError; `subsidy` was not even a keyword of this function.
    if !(type in ("NC", "C", "C-fix"))
        throw(ArgumentError("No other options than type == {C,NC,C-fix}"))
    end
    if type == "C-fix" && isnothing(subsidy)
        throw(ArgumentError("Error: Need to specify subsidy rate"))
    end

    tax_rate = isnothing(tax) ? copy(vat_tax) : tax

    # Per-firm objects at prices (w, Y) for the requested firm type (no VF needed).
    function profit_objects(w, Y)
        if type == "NC"
            res = compute_profits_NC(w = w, Y = Y, tax = tax_rate, param = param, z_grid = z_grid_baseline)
            return (results = res, revenue = res.revenue, output_aggr = res.output_aggr)
        elseif type == "C"
            res = compute_profits_grid_cf(w = w, Y = Y, tax = tax_rate, n_eps = n_eps, param = param, z_grid = z_grid_baseline)
        else # "C-fix"
            res = compute_profits_grid_cf_fix_subsidy(w = w, Y = Y, subsidy = subsidy, tax = nothing, param = param, z_grid = z_grid_baseline)
        end
        return (results = res, revenue = res.expected_revenue, output_aggr = res.expected_revenue_output)
    end

    guess_w_new = copy(guess_w)
    guess_Y_new = copy(guess_Y)
    diff = Inf
    iter = 1

    while diff > crit && iter < max_iter
        if verbose
            println("Iteration ",iter," with difference: ",diff, " and w: ", guess_w_new)
        end

        objs = profit_objects(guess_w_new, guess_Y_new)

        # labor market clearing
        total_labor_demand = sum( (1-tax_rate) .* beta_gross .* (objs.revenue ./ guess_w_new) .* distrib) * mass
        diff_L = (total_labor_demand - aggr_L)/aggr_L
        println("Difference Labor demand is: ", diff_L)

        # output consistency
        total_Y = sum(objs.output_aggr .* distrib) * mass
        diff_Y = (total_Y - guess_Y_new)/guess_Y_new
        println("Difference Y is: ", diff_Y)

        # damped updates
        guess_w_new = guess_w_new * (1.0 + diff_L*update_param_w)
        guess_Y_new = guess_Y_new * (1.0 + diff_Y*update_param_Y)

        diff = max(abs(diff_L), abs(diff_Y))
        iter += 1
    end

    ### compute aggregates at the final guesses
    w_eq = copy(guess_w_new)
    Y_eq = copy(guess_Y_new)
    final_objs = profit_objects(w_eq, Y_eq)

    total_intermediates = sum((gamma_gross .* final_objs.revenue) .* distrib) * mass
    total_rent_seeking = type == "C" ? sum(param.proba_c .* final_objs.results.expected_m_R_C .* distrib) * mass : 0.0
    output_VA = Y_eq - total_intermediates - total_rent_seeking

    return (w = w_eq, Y = Y_eq, output_VA = output_VA)
end

# For decomposing results
"""
    find_equilibrium_cf_fixdistrib(; guess_w, distrib, mass, aggr_L, tax = nothing, param,
                                   update_param_w = 0.5, crit = 1e-6, max_iter = 500,
                                   verbose = false)

Solve for the wage that clears the labor market for a fixed distribution `distrib` and a fixed
total firm mass `mass`, using the non-constrained profit function `compute_profits_NC`.

Output is tied to the wage by `Y = w * aggr_L / (beta_gross * (1 - tax))`. The wage is scaled by
`1 + update_param_w * (L_demand - aggr_L) / aggr_L`. Iteration stops when the absolute
relative gap is at most `crit`, or after `max_iter - 1` iterations. `tax` defaults to the
global `vat_tax`.

Returns `(w, Y, output_VA)` with `output_VA = Y - sum(gamma_gross * revenue * distrib) * mass`.
"""
function find_equilibrium_cf_fixdistrib(; guess_w, distrib, mass, aggr_L, tax = nothing, param, update_param_w = 0.5, crit = 1e-6, max_iter = 500, verbose = false)

    rate = isnothing(tax) ? copy(vat_tax) : tax

    # output implied by the wage through the labor share
    implied_Y(wage) = (wage*aggr_L)/(beta_gross*(1-rate))

    w = copy(guess_w)
    gap = Inf
    n = 1
    while gap > crit && n < max_iter
        if verbose
            println("Iteration ",n," with difference: ",gap, " and w: ", w)
        end

        revenue = compute_profits_NC(w = w, Y = implied_Y(w), tax = rate, param = param, z_grid = z_grid_baseline).revenue

        labor = sum( (1-rate) .* beta_gross .* (revenue ./ w) .* distrib) * mass
        gap_L = (labor - aggr_L)/aggr_L
        println("Difference Labor demand is: ", gap_L)

        w = w * (1.0 + gap_L*update_param_w)
        gap = abs(gap_L)
        n += 1
    end

    w_eq = copy(w)
    Y_eq = implied_Y(w_eq)
    revenue_eq = compute_profits_NC(w = w_eq, Y = Y_eq, tax = rate, param = param, z_grid = z_grid_baseline).revenue

    output_VA = Y_eq - sum((gamma_gross .* revenue_eq) .* distrib) * mass

    return (w = w_eq, Y = Y_eq, output_VA = output_VA)
end

## DRS version
"""
    get_stationary_distrib_DRS(; initial_distrib, entry_distrib, exit_proba, crit = 1e-12,
                               max_iter = 1000, verbose = false)

Iterate the firm distribution to a stationary point, using the global DRS productivity
transition matrix `DRS_transition_z`.

Each step:
1. firms exit with the per-grid-point probability `exit_proba`;
2. survivors move with `DRS_transition_z'`;
3. an entrant mass equal to the exit mass arrives, distributed as the normalized
   `entry_distrib`.

Iteration stops when the largest absolute change is at most `crit`, or after `max_iter - 1`
steps.

Returns `(SS_distrib, mass_entry)`. `mass_entry` is the exit (= entry) mass computed in the
last step, or `0.0` if no step ran.
"""
function get_stationary_distrib_DRS(;initial_distrib,entry_distrib,exit_proba,crit = 1e-12, max_iter = 1000, verbose = false)
    entry_shares = entry_distrib ./ sum(entry_distrib)

    current = copy(initial_distrib)
    updated = copy(initial_distrib)
    exit_mass = 0.0
    gap = Inf
    n = 1

    while gap > crit && n < max_iter
        verbose && println("Iteration ",n," with difference: ",gap)

        # exiting mass is replaced one-for-one by entrants
        exit_mass = sum(current .* exit_proba)
        stayers = DRS_transition_z' * (current .* (1.0 .- exit_proba))
        updated .= stayers .+ entry_shares .* exit_mass

        gap = maximum(abs.(updated .- current))
        n += 1
        current = copy(updated)
    end

    return (SS_distrib = updated, mass_entry = exit_mass)
end

"""
    find_equilibrium_cf_DRS(; guess_w, guess_Y, type = "NC", tax = nothing, subsidy = nothing,
                            param, update_param_w = 0.5, update_param_Y = 0.8, crit = 1e-6,
                            max_iter = 500, max_iter_w = 500, verbose = false)

Stationary equilibrium of the DRS economy with entry costs in labor units (`DRS_entry_cost`).

The solver has two nested loops:
- Outer loop on `Y`:
  1. Find the wage at which the expected entry value
     `sum(DRS_z_star_uncond_distrib .* VF)` equals `w * DRS_entry_cost`. This is done by
     damped updates `w *= 1 + update_param_w * gap` until `|gap|` is at most `crit`.
  2. Compute the stationary distribution with `get_stationary_distrib_DRS`.
  3. Set the firm mass so that production labor plus entry labor equals
     `DRS_aggregate_labor_supply`.
  4. Update `Y *= 1 + update_param_Y * (Y_implied - Y) / Y`.
- The expected-value guess passed to `get_VF_cf_DRS` is warm-started from the previous solve.
  It is initialized at `profits_NC_baseline / (1 - beta)`.

`tax` defaults to the global `vat_tax`.

Returns `(w, Y, VF_results, distrib_results, mass, output_VA)`. `output_VA` is `Y` minus
intermediates; it may still include rent seeking.
"""
function find_equilibrium_cf_DRS(; guess_w, guess_Y, type = "NC", tax = nothing, subsidy = nothing, param, update_param_w = 0.5, update_param_Y = 0.8, crit = 1e-6, max_iter = 500, max_iter_w = 500, verbose = false)

    exp_VF_guess = (1/(1-beta))*profits_NC_baseline  # warm start, refreshed after every VF solve
    rate = isnothing(tax) ? copy(vat_tax) : tax

    # value function at prices (w, Y) given an expected-value guess `ev`
    solve_VF(w_, Y_, ev) = get_VF_cf_DRS(w = w_, Y = Y_, type = type, subsidy = subsidy, tax = rate, param = param, guess_exp_VF = ev)

    # stationary distribution and the firm mass that clears the labor market at wage `w_`
    function stationary_state(VF, w_)
        dist = get_stationary_distrib_DRS(
            initial_distrib = DRS_z_star_uncond_distrib,
            entry_distrib = DRS_z_star_uncond_distrib,
            exit_proba = VF.exit_proba)
        labor_per_unit_mass = sum( (1-rate) .* beta_gross .* (VF.revenue ./ w_) .* dist.SS_distrib) + dist.mass_entry*DRS_entry_cost
        return dist, DRS_aggregate_labor_supply / labor_per_unit_mass
    end

    w_outer = copy(guess_w)
    Y = copy(guess_Y)
    gap = Inf
    n = 1

    while gap > crit && n < max_iter
        if verbose
            println("Iteration ",n," with difference: ",gap, " and Y: ", Y)
        end

        ## (1) wage consistent with free entry
        w = copy(w_outer)
        gap_w = Inf
        n_w = 1
        while abs(gap_w) > crit && n_w < max_iter_w
            VF_w = solve_VF(w, Y, exp_VF_guess)
            entry_value = sum(DRS_z_star_uncond_distrib .* VF_w.VF)
            gap_w = (entry_value - w*DRS_entry_cost)/(w*DRS_entry_cost)
            w = w * (1.0 + gap_w*update_param_w)
            n_w += 1
            exp_VF_guess = VF_w.exp_VF
        end
        eqlbm_w = copy(w)
        println("Found wage in line with free entry. w = ",eqlbm_w)

        ## (2) stationary distribution and firm mass at that wage
        VF_results = solve_VF(eqlbm_w, Y, exp_VF_guess)
        distrib_results, new_mass = stationary_state(VF_results, eqlbm_w)

        ## (3) implied output and damped update
        total_Y = sum(VF_results.output_aggr .* distrib_results.SS_distrib .* new_mass)
        diff_Y = (total_Y - Y)/Y
        println("Difference Y is: ", diff_Y, " with new mass: ", new_mass)

        Y = Y * (1.0 + diff_Y*update_param_Y)
        gap = abs(diff_Y)
        n += 1
        exp_VF_guess = VF_results.exp_VF
        w_outer = copy(eqlbm_w)
    end

    ### final objects at the converged prices
    VF_results = solve_VF(w_outer, Y, exp_VF_guess)
    distrib_results, final_mass = stationary_state(VF_results, w_outer)

    total_intermediates = sum((gamma_gross .* VF_results.revenue) .* distrib_results.SS_distrib) * final_mass
    output_VA = Y - total_intermediates

    return (w = w_outer, Y = Y, VF_results = VF_results, distrib_results = distrib_results, mass = final_mass, output_VA = output_VA)
end

# get version with entry costs denoted in output goods 
"""
    find_equilibrium_cf_DRS_goods(; guess_w, guess_Y, aggr_L, type = "NC", tax = nothing,
        subsidy = nothing, param, update_param_w = 0.5, update_param_Y = 0.8,
        crit = 1e-6, max_iter = 500, max_iter_w = 500, verbose = false)

Solve for a stationary `(w, Y)` equilibrium in the variant where entry does not use
labor (labor demand below contains only producing firms, no entry term).

Nested damped fixed point:
1. Inner loop (at most `max_iter_w - 1` steps): for the current output guess `Y`, update
   the wage until the expected entry value `sum(DRS_z_star_uncond_distrib .* VF)` matches
   `w * DRS_entry_cost` up to relative tolerance `crit`.
2. Compute the stationary distribution implied by the resulting exit probabilities and
   choose the firm mass so that labor demand equals `aggr_L`.
3. Compare aggregate output with the guess and update `Y`. Stop once the relative gap is
   below `crit` or after `max_iter - 1` iterations.

NOTE(review): the free-entry condition prices entry as `w * DRS_entry_cost` (labor units)
although this variant is described as having entry costs in output goods — confirm this
is intended.

# Keywords
- `guess_w`, `guess_Y`: initial guesses for the wage and aggregate output.
- `aggr_L`: aggregate labor supply used to pin down the mass of firms.
- `type`, `subsidy`, `param`: forwarded to `get_VF_cf_DRS`.
- `tax`: forwarded to `get_VF_cf_DRS` and used in labor demand; defaults to `vat_tax`.
- `update_param_w`, `update_param_Y`: damping weights of the wage / output updates.
- `crit`: relative convergence tolerance used by both loops.
- `max_iter`, `max_iter_w`: iteration caps of the output loop and the wage loop.
- `verbose`: print per-iteration progress.

# Returns
Named tuple `(w, Y, VF_results, distrib_results, mass, output_VA)`.
- `VF_results` and `distrib_results` are recomputed at the final `(w, Y)`.
- `mass` is the labor-market-clearing firm mass.
- `output_VA` is `Y` minus intermediates (`gamma_gross .* revenue`, aggregated).

Reads module-level globals `beta`, `profits_NC_baseline`, `vat_tax`,
`DRS_z_star_uncond_distrib`, `DRS_entry_cost`, `beta_gross` and `gamma_gross`.
"""
function find_equilibrium_cf_DRS_goods(; guess_w, guess_Y, aggr_L, type = "NC", tax = nothing, subsidy = nothing, param, update_param_w = 0.5, update_param_Y = 0.8, crit = 1e-6, max_iter = 500, max_iter_w = 500, verbose = false)

    # initialize objects 
    diff = Inf 
    iter = 1
    guess_w_new = copy(guess_w)
    guess_Y_new = copy(guess_Y)
    # start with an initial guess for the expected VF (updated inside the loops)
    guess_exp_VF = (1/(1-beta))*profits_NC_baseline

    # default to the baseline VAT rate when no tax is supplied
    if isnothing(tax)
        tax = copy(vat_tax)
    end 

    # start loop of iterating on guesses for (w,Y)
    while diff > crit && iter < max_iter

        # print progress
        if verbose
            println("Iteration ",iter," with difference: ",diff, " and Y: ", guess_Y_new)
        end 

        ### Step 1: find wage that is in line with free entry condition 

        diff_w = Inf 
        iter_w = 1
        w_new = guess_w_new

        while abs(diff_w) > crit && iter_w < max_iter_w

            # VF given (w,Y) guess 
            VF_results_w = get_VF_cf_DRS(w = w_new, Y = guess_Y_new, type = type, subsidy = subsidy, tax = tax, param = param, guess_exp_VF = guess_exp_VF)

            # expected value of entry under the entrants' productivity distribution
            entry_value = sum(DRS_z_star_uncond_distrib .* VF_results_w.VF)

            # relative free-entry gap; raise the wage when entry is too valuable
            diff_w = (entry_value - w_new*DRS_entry_cost)/(w_new*DRS_entry_cost)
            w_new = w_new * (1.0 + diff_w*update_param_w)

            iter_w = iter_w + 1 
            # warm-start the next VF solve
            guess_exp_VF = VF_results_w.exp_VF
        end 

        # the only way to leave the loop with a large gap is hitting the iteration cap
        if abs(diff_w) > crit
            @warn "Free-entry wage loop did not converge" iter_w diff_w
        end

        eqlbm_w = w_new
        if verbose
            println("Found wage in line with free entry. w = ",eqlbm_w)
        end

        # new VF & exit probability given (w,Y) guess 
        VF_results = get_VF_cf_DRS(w = eqlbm_w, Y = guess_Y_new, type = type, subsidy = subsidy, tax = tax, param = param, guess_exp_VF = guess_exp_VF)
        
        ### Step 2: stationary distribution & equilibrium mass of firms 

        distrib_results = get_stationary_distrib_DRS(
            initial_distrib = DRS_z_star_uncond_distrib, 
            entry_distrib = DRS_z_star_uncond_distrib, 
            exit_proba = VF_results.exit_proba)

        # labor demand per unit mass of firms (no entry labor in this variant);
        # the mass is chosen so that the labor market clears
        total_labor_demand_normalized = sum( (1-tax) .* beta_gross .* (VF_results.revenue ./ eqlbm_w) .* distrib_results.SS_distrib)
        new_mass = aggr_L / total_labor_demand_normalized

        ### Step 3: compute aggregates & update guess 

        total_Y = sum(VF_results.output_aggr .* distrib_results.SS_distrib .* new_mass)
        diff_Y = (total_Y - guess_Y_new)/guess_Y_new

        if verbose
            println("Difference Y is: ", diff_Y, " with new mass: ", new_mass)
        end

        # damped update of the output guess
        guess_Y_new = guess_Y_new * (1.0 + diff_Y*update_param_Y)
                
        # update diff & iterate 
        diff = abs(diff_Y)
        iter = iter + 1 
        guess_exp_VF = VF_results.exp_VF
        guess_w_new = eqlbm_w
    end 

    if diff > crit
        @warn "Output loop did not converge" iter diff
    end

    ### compute final objects 

    # VF results 
    VF_results = get_VF_cf_DRS(w = guess_w_new, Y = guess_Y_new, type = type, subsidy = subsidy, tax = tax, param = param, guess_exp_VF = guess_exp_VF)
    
    # stationary distribution 
    distrib_results = get_stationary_distrib_DRS(
        initial_distrib = DRS_z_star_uncond_distrib, 
        entry_distrib = DRS_z_star_uncond_distrib, 
        exit_proba = VF_results.exit_proba)

    # labor-market-clearing mass of firms 
    total_labor_demand_normalized = sum( (1-tax) .* beta_gross .* (VF_results.revenue ./ guess_w_new) .* distrib_results.SS_distrib)
    final_mass = aggr_L / total_labor_demand_normalized

    # value added output (may still include rent-seeking!!)
    total_intermediates = sum((gamma_gross .* VF_results.revenue) .* distrib_results.SS_distrib) * final_mass
    output_VA = guess_Y_new - total_intermediates

    return (w = guess_w_new, Y = guess_Y_new, VF_results = VF_results, distrib_results = distrib_results, mass = final_mass, output_VA = output_VA)
end 